{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST with SciKit-Learn and skorch\n",
    "\n",
    "This notebook shows how to define and train a simple neural network with PyTorch and use it via skorch with scikit-learn.\n",
    "\n",
    "<table align=\"left\"><td>\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb\">\n",
    "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>  \n",
    "</td><td>\n",
    "<a target=\"_blank\" href=\"https://github.com/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note**: If you are running this in [a colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb), we recommend you enable a free GPU by going to:\n",
    "\n",
    "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n",
    "\n",
    "If you are running in colab, you should install the dependencies by running the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "! [ ! -z \"$COLAB_GPU\" ] && pip install torch scikit-learn==0.20.* skorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Data\n",
    "We use scikit-learn's ```fetch_openml``` to load the MNIST data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = fetch_openml('mnist_784', cache=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(70000, 784)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mnist.data.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing Data\n",
    "\n",
    "Each image of the MNIST dataset is encoded as a 784-dimensional vector, representing a 28 x 28 pixel image. Each pixel has a value between 0 and 255, corresponding to its grey value.<br />\n",
    "The ```fetch_openml``` call above does not return ```data``` and ```target``` in the dtypes we need for PyTorch, so we convert them to ```float32``` and ```int64``` respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = mnist.data.astype('float32')\n",
    "y = mnist.target.astype('int64')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To keep the network from having to cope with large input values in the range [0, 255], we scale `X` down. A commonly used range is [0, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X /= 255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, 1.0)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.min(), X.max()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: the data is only rescaled to [0, 1]; it is not standardized to zero mean and unit variance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert(X_train.shape[0] + X_test.shape[0] == mnist.data.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((52500, 784), (52500,))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.shape, y_train.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Print a selection of training images and their labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_example(X, y):\n",
    "    \"\"\"Plot the first 5 images and their labels in a row.\"\"\"\n",
    "    for i, (img, label) in enumerate(zip(X[:5].reshape(5, 28, 28), y[:5])):\n",
    "        plt.subplot(151 + i)\n",
    "        plt.imshow(img)\n",
    "        plt.xticks([])\n",
    "        plt.yticks([])\n",
    "        plt.title(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAABbCAYAAABEQP/sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAFBhJREFUeJztnXl8FFW2x7+3EyCEfScsEhCiiCg+EBRU3JUIiAMuuOEGIuJze6KiDg/HUVBxVBQVQURgnuICsimKjiKyDPAAUUE2g8hOBILEBJK+88ep6tChExJJp6qT8/18+HRSdas5fVN963fPPfccY61FURRF8Z6A1wYoiqIogg7IiqIoPkEHZEVRFJ+gA7KiKIpP0AFZURTFJ+iArCiK4hN0QFYURfEJvh2QjTGTjTHbjTEZxph1xpg7vLbJS4wxlYwx440xm40xB4wxK4wx3by2y2uMMcnGmDnGmL3GmB3GmFeMMfFe2+UVep8UTCyMKb4dkIFngGRrbXWgJ/CUMaa9xzZ5STywBegK1ACeAKYaY5I9tMkPjAF2AUlAO6R/BnlqkbfofVIwvh9TfDsgW2t/sNZmu786/0700CRPsdYetNb+r7U2zVobtNbOAn4GfHVDeUBzYKq1NstauwP4FGjjsU2eofdJwcTCmOLbARnAGDPGGJMJrAW2A3M8Nsk3GGMaACnAD17b4jEvAdcZYxKNMY2BbsigrKD3SX78Pqb4ekC21g4CqgHnAh8B2YVfUT4wxlQApgATrbVrvbbHY75GFHEG8CuwDJjuqUU+Qe+To/H7mOLrARnAWptrrV0ANAHu8toerzHGBIBJwCFgsMfmeIrTF3ORL1YVoC5QCxjppV1+QO+TgvHzmOL7AfkI4vGZv6e0McYYYDzQAOhtrT3ssUleUxtoCrxirc221qYDE4BUb83yFr1PiozvxhRfDsjGmPrGmOuMMVWNMXHGmMuAvsCXXtvmMa8BrYEe1to/vDbGa6y1e5AFq7uMMfHGmJpAP2CVt5Z5jt4n+YiVMcX4MR+yMaYe8AFwOvLQ2Ay8bK1901PDPMQY0wxIQ3xeOUecutNaO8UTo3yAMaYd8CJyr+QC/wLuttbu8tQwj9D7JDKxMqb4ckBWFEUpj/jSZaEoilIe0QFZURTFJ+iArCiK4hN0QFYURfEJOiAriqL4hGKlKaxoKtkEqkTLFl+QxUEO2WxT1PbloU8ADrB3j7W2XlHaap9Epjz0i35/IlPUe6VYA3ICVehkLvrzVsUAS+wXxWpfHvoEYJ79YHNR22qfRKY89It+fyJT1HtFXRaKoig+QQdkRVEUn6ADsqIoik8ot7XHYoFAYiIA+3ucBsCB6zIAsFbWTLo02QTAz/elAGAWlvecOoqSx857OgMwaJCkx761+hYATvn6dgCSplYEoPL0f3tgXWRUISuKovgE3yhk2/l0ADZcnwDAFz1GAfD2vk4AzN/VEoC0TfUBSJ4m1yV88yMAwYMHS83WaLNufAcAJl8wFoCOlb4B4JccyaQ4eV9HAKrFZQFw15SvAHjs7J4A5OzYWWq2+o34xo3kByOziJxftwLw221nA/DEoxMB6FklE4BcGwxd22WI1EatMWVxqdhaGpgzpLzg+gdEDa67cHzY+fEZTQD4uJvcczlpv5SidSWDaS+f8bcnD4UdX9puNABBgs6r8H1XSfD23KltAZjRv23E9/19kUSpNZ+QBkDO1m0lZnNBqEJWFEXxCZ4r5LT3xD/6XkdRgxsPy1NpzsHWANSvIH7TeaeIJA6e4qQL7S4vA7d0BWDhp+IvSh61WtodOBBly6NHjZWiZh6dMRCASuny5I/PEEUc/E7Ko8XVaQhAo4V7Adh6jRQ/aPBy+VPI8U0aA3DidEmD3LnaBgAWZIh/vXuttwC4qLIo48MRss7OGiGzspu+vwOA4Ko10TM4SpgKcu9sv1sU77v3PQ9AywqVgDyV
6OL6VUcMvxyAlNtFBdqcHGKFrPqy1vJNu7fznSlcbz5cR+q+PlRndcTzgXZy/a2pEie99+pGoXPRUsuqkBVFUXyC5wr5shNFhfR9+34Akl+IrHBndhAlvO/kqgDcNHQ2AK83/RqAQP/5AFzU+S8AVO6dd22sqeUGoxdGPJ5f3eSm/wbA48t6AVA1mkb5nJ+elZnV9KQZYcd7V90DQKaVWca1G8XPvnqRrEl8dt1zobZN4isDkPa4fC2S+4nyCmZmRsvsEuen0e0AWNdD/KcBZE0mSOGFKNZfPA6AK9rdLAeWfR8lC6PHuP0tAHj+G1H79RbK37HuosKLx2SdUBOAn68Wfdql7XoAJjSTXYfjm30OwOVtB4WuqagKWVEUpWzjuUJef04cACdkiSrMrwJdrPPErrFMfp89S56GH7e+EID4p3cD8HmbDwG4fNpVoWsrdpeiu8GsrJIz3EdUXSLKrkUfebIffNlLa0qXtL9J9MSYjuMinp+dWQOAoRNF+TV9Su6zZhc0BWDb1Ymhtk3iRUX2b/0tAPMqJ8sJHytk12ecOUt86D+0ecU5ExfW7os/5HPe+95tADT6RnzEk8b+A4CkODm/86zqANRfFj2bS5pKnywFYMYndQBIYWnY+dxjXF9hnbymzJPXxaPOkuPJXwGR1xuihSpkRVEUn6ADsqIoik/w3GXxZ90Iufv2A2AWyXZhe4VMuUYuliDxz1pPD7Xt2UgW+oKb0v6smb6mVg/Z/LBytbhxWrHbS3NKhezUMwFYfqtMuSuZCmHnNxzOBuCxCY6r4unwhdJbX/8YgI6V8uajvwflmtELLgYgJd0/W2oLYt0LZwDw06ljnCPhropvs0VzjbqlLwDJCxaFnf/bjksAGNNY3DT7ThFXRv2oWOtP4urUBmDrTScD8OXVzwJw2IorcOz+ZAASV28NXROtoEBVyIqiKD7Bc4VcUgTqiUM/KyghbkeG+eTWKNsVCbbskie8KXqhhpglvb8s4rXvvxI4Whl/klkNgNf63AhAk1WijOObySJe5jjRIJckfutckRC6NvX7mwBIGeh/Zewy5KJZAAQI/9u7ynjA5LsAaJZPGbvEGeu8Otqs7N9CIbK6SwqC1Gf+BcB9tec6Z2QTzc5cSVXw9vOyC6321sh9WJKoQlYURfEJMauQ4+qKIs7sJH7Tzk+L4hlaVzaWXLvx8lBbs0ESppRi9EqpsGeAqMV/dpY4tyGDBxXWPKYJnC5b6f86RJIDdUuMvNln6GoJdzxhh/jR178kIUwPXiIbiRpXkG3mtQKijN0NIwBZMxsAUJ2NJWp7NMl1NJU7I3TD21yfcUHKOHS9k8o1lGSprH1JjmDXIEmvcNvdci+kVn0BgCbxlSK27zP0fwCoPTn6ythFFbKiKIpPiDmFvP0Becp1v2kBAMPrfwbk+dBazr0TgJMGrAxdE0uJUopD78FfAjB+93kAVJq9tLDmMc2+NrLBoyBl7LKi0zvyw/8X/n4zDtYCYPgbN4aOJY2JvGU9Fui8QhRxXWfbd2DlysKahyILTq0SPhtI/DXmhoQCiWsgsSI/PSKz6LXXjA47H0BmE3uD4isetFm21f9+pUwTaqSXfhpWVciKoig+wXePQ3crKKe1AmDjNbKV880+bwBwbkK49Jl0IAmAN56UWOOUf8pTrQy7wtj53zJLuLmmxEv2vfdBABJZ4plN0cJUEv/egT4lkyDKjTUee7MkZEpaHLuqGGD2BeJbr7tf1kmOFdfvKuOEaRKvPLCmlAHb7kQU1NhYUPIC/+J+psyOkn52x63SBw+0leRA/aqLz/joTyZ69NzFkua2wVuyrlD5jx+iaG3hqEJWFEXxCZ4p5LjqonzXjJDdMd3O/A6A2hWkFNOweu+EtXd9xO5qcqenBgOQ9InsnqmeVnbK7hyLhr02A/DAL1cCkDitDCrjDqcCMOoDKbeTUuH4lOz8LJl5jTrLib7Z/d1xvZ9fyN1ZeGpJl8MXtweg+0uSQcdV
xi4XL3bild+Lve+Rq4znjhsTdjwQikApnJWdpXgBMvFk+C7pq/9bJBE6rV9MByB3XfSjb1QhK4qi+ATPFLK1onT7nCWRASMaLAfyFPCwXbJHf/ch2XlVJV58f881FDWYfbGUduKT0rHXD/z8jMQdL2wlZXm6jnkIgCbEth/UJa51q9DPL34oawbN4xMKal4szkuQeOO7XpYyPM37lu18H/HJJwDw42MSW70h9Y18LWTG6e7oSx4pOrIsrb3cullKL327ulXE8307ylgyrP7ysOPu78OvXAHA1W1SAfija1TMDEMVsqIoik/wTCG7ZZW+7ywKqGed7mHnXd+YzRGfcqCK5KN4Z5kk4l55lviYewV6U15Y1+81AFrMuweAVs+UEWXsxIu2npznoytIGa84JEpu8I/Xhx3fky4zqaSZkttiu1OU4KeL3wxr16jO/hKw2L8cuE78nqlDvwJgRh3Jepi/hNNV6+X7lv24KOjA8sLjlv2Mm6C+Z+Mz853ZBxydsN5luaNHeyLXmfaSKXLbE3KPfXyG3Dvvt5wDwIQ1TUPXTj9f1jiK6sMvKqqQFUVRfILncchu3GTwGEUDgwdFKf/93+LPucUpyrivvTzhq5bRXMcAWx+W5d+Nh2V3YtOpnv/ZSpSfHnXKcTUseEFgwJbzAVg1UZRJvdfC8wvUytfepraP+D6/fSo+5CTSim+oD9l/gyjihH47AJjRehQANQLuDCM8fduD26V9sLesyQTSY1cZlzR2ucQfJ0mIOhe+KYWX16ZK9Ea/6ptDbafXkPUcVCEriqKUTWJGarn5bLu3kWxurk+syq9ls3ApQFyKxFcO7ifVLXq9OgSARrPKhu/YpVuXFQWe+zlH/r47Bsrfv97KwjNvbXlCZhN92y0IO55txadcc1Ns5zVx/e0tZokvfETSi8CReaHDfe9fZcnxIc8OAKD+ROnrYNZv0Tb1uAh2lSirjddI/PgJs8Sv6/qLyyqqkBVFUXxCzCjkLX1EIU1PklVjN045sFT8PmUpfjK+ofjF+8+WCgYZuaJ6Gj1btpRxURi1U+rbBVf+GLlBQHIybHuwEwAf3i4x2i0rhOe4nZMpfVp5euxUA4nEmmHJAMxs9DoAQcIrpjyyQyIGZs6V/mg1Vnay1k1b5LSPDTbcKEPT2tRXANjZQ3zeF06VHMXN5hwOtY3/cjklycYpMrbMO+cfzhG5l8btb5HXaG9Gif6fLqqQFUVRfILvFXJ8i2QAHuj/AQC/5EhWqgXDZbW4ck5sK55IbLlefMc9EiXq4Lz7JWdvVWIvz0BhxDdvBsClNT8vsM3nS04DoJWTyS6+sURJbOuVDEC/QRIjendNN9dt5OoPj30kccstKL3qDyWFObNt6OcNV4oyDtXAcyp9nD30bgBqTZTP19z5nAV5zN1cMl2/3Q7A5IlSfbrR8/6YhQUOyszHzUeRFCcVoNf0fVWO982LHnHXkzouuwGA+n+XWUPcRpkd5KYX7i+3XdrJ65OSs2LNyeOc/1vyJX/o5M6e0z0vcid3d1qxP1NRUIWsKIriE0pUIbtP3WBmJnB8lTriTmoJwJpHpVLELdUl3i/la1ECzWPcFxiJQ5d1AGDWfZLnOIiogldHvgRAxghRf0Eb+TkaMMGw8z9ky67G2efKXv5jKYXSJlhdFMiJFdKdIxWPavPtlRJX+5eW/QB4v43U1GvgKKZj0Wm5KOMTh0ke7Vhca9h9RtXQz6Edd44yDuV+eWICAI/0krzg2U50BT8nRnzPRh1EGT9QWyouT7GXlKzRx8lJf5U1g/NOugaAr057N1+LvO9A0PGML+4wSQ5Mk5fn0mVmMWNL27Ar931XF4Cap+0B4PGUqQBclrjfeT/nbQ5KuyffEuXdeFP0Zw+qkBVFUXxCiShkVxn/13yp6LvsjtPlxLLvi26IE2e89n5RdS93fxuASyvLDr0b0yRzU8t7fgUg9/hM9iW77xT/eFI+9demolMnzXl2ZwRlBrLqUNWw
dlPTZWV9fUY9ALo3lJjt38+V2YbfIgyCq9YAcPV4qXiyauDoo9rUdfpi/mlTnSOFK+M9TuWLK1beDkCDG2QHaDA7+7jt9YqEfcfW9e735NKOk8KOB84LzyMeK+RmSBRDLSfaouON9wLQ67avAXi87rHzWT9cRyKwHqqzOvyEuIyPyJccHnuSukby4yQMkPOloYxdVCEriqL4BB2QFUVRfEKJuCz2XyaFFofVkyQcqSOTAdiwrqM0ODK/Sb6Z08BzZFGhU6IUIuySIAHfH/4uDvUzR94GQKNJMr3N3ZtOWSVuvixgIp4HLl9zFQB7ZjYBINeZrddbKX1U8dP820jdabm4dWY5KXcq4y9XRX6azRBX19CrOoSOPd1gWbHeY4pT7Pbdmy4DoN5Sd4t97FPto7y+OOmcQQDM7ClbplMqHL0QWpZwF6IbviRug6WTZYNPx773hNrsa3s47Jq5l0vfFLW4geui2LRR3rv142kA5Owu/SIGqpAVRVF8gnFLKRWF6qa27WQuOvqEs3114yQJ4n/DSR7vls0JcHQQt4tbfPLRtRKuk/uxKOMG0zbI76X8lFpivyDD/maO3VIosE/KGPPsB8uttR2O3fLP90kgMS9Ea8+1sjB8sLH8KTp1F8U7f33LsGsafyjhXdWW/AJAzvYdxf5//yzF6RMo2Xsl0O4UANYOksINz5//HgA9q+wNb0fkRb3x+6XE0/vbZLNDwh1yPCftl+OyS78/kSnqvaIKWVEUxSeUzMaQoAShnXiDpPZ7lrbOa9GpxXrnJ3kti2FtSuG4G4oAak+Qrb+1nd+3PSWvLYmcqjO2k2oWHzfZUopk1WQsLZzX4hGPM7MoKcOU40IVsqIoik/QAVlRFMUn6ICsKIriE3RAVhRF8Qk6ICuKovgEHZAVRVF8gg7IiqIoPqFYO/WMMbuBzdEzxxc0s9bWK2rjctInUIx+0T6JTDnpF+2TyBSpX4o1ICuKoijRQ10WiqIoPkEHZEVRFJ+gA7KiKIpP0AFZURTFJ+iArCiK4hN0QFYURfEJOiAriqL4BB2QFUVRfIIOyIqiKD7hP4IZ9oXRrqxFAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7ff569a4a320>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_example(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build Neural Network with PyTorch\n",
    "A simple, fully connected neural network with one hidden layer. The input layer has 784 dimensions (28x28), the hidden layer has 98 neurons (= 784 / 8), and the output layer has 10 neurons, representing the digits 0-9."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_dim = X.shape[1]\n",
    "hidden_dim = int(mnist_dim/8)\n",
    "output_dim = len(np.unique(mnist.target))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(784, 98, 10)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mnist_dim, hidden_dim, output_dim"
   ]
  },
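  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the three layer sizes above, the number of trainable parameters can be worked out by hand as a rough check on the size of the model. The sketch below assumes each fully connected layer stores one weight per input-output pair plus one bias per output; the helper name `linear_params` is ours and is not part of PyTorch or skorch.\n",
    "\n",
    "```python\n",
    "def linear_params(n_in, n_out):\n",
    "    # One weight per (input, output) pair plus one bias per output.\n",
    "    return n_in * n_out + n_out\n",
    "\n",
    "mnist_dim, hidden_dim, output_dim = 784, 784 // 8, 10\n",
    "\n",
    "hidden_params = linear_params(mnist_dim, hidden_dim)   # 784 * 98 + 98\n",
    "output_params = linear_params(hidden_dim, output_dim)  # 98 * 10 + 10\n",
    "n_params = hidden_params + output_params\n",
    "print(hidden_params, output_params, n_params)\n",
    "```\n",
    "\n",
    "Almost all of the parameters sit in the hidden layer, because it is connected to every one of the 784 input pixels."
   ]
  },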
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A neural network defined as a PyTorch module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifierModule(nn.Module):\n",
    "    def __init__(\n",
    "            self,\n",
    "            input_dim=mnist_dim,\n",
    "            hidden_dim=hidden_dim,\n",
    "            output_dim=output_dim,\n",
    "            dropout=0.5,\n",
    "    ):\n",
    "        super(ClassifierModule, self).__init__()\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "        self.hidden = nn.Linear(input_dim, hidden_dim)\n",
    "        self.output = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        X = F.relu(self.hidden(X))\n",
    "        X = self.dropout(X)\n",
    "        X = F.softmax(self.output(X), dim=-1)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "skorch lets you use PyTorch networks in the scikit-learn setting:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch import NeuralNetClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "\n",
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.8299\u001b[0m       \u001b[32m0.8893\u001b[0m        \u001b[35m0.4037\u001b[0m  0.8951\n",
      "      2        \u001b[36m0.4331\u001b[0m       \u001b[32m0.9113\u001b[0m        \u001b[35m0.3075\u001b[0m  0.7862\n",
      "      3        \u001b[36m0.3619\u001b[0m       \u001b[32m0.9240\u001b[0m        \u001b[35m0.2614\u001b[0m  0.8314\n",
      "      4        \u001b[36m0.3237\u001b[0m       \u001b[32m0.9305\u001b[0m        \u001b[35m0.2379\u001b[0m  0.7739\n",
      "      5        \u001b[36m0.2914\u001b[0m       \u001b[32m0.9371\u001b[0m        \u001b[35m0.2173\u001b[0m  0.7721\n",
      "      6        \u001b[36m0.2739\u001b[0m       \u001b[32m0.9413\u001b[0m        \u001b[35m0.1979\u001b[0m  0.7949\n",
      "      7        \u001b[36m0.2569\u001b[0m       \u001b[32m0.9449\u001b[0m        \u001b[35m0.1859\u001b[0m  0.7691\n",
      "      8        \u001b[36m0.2420\u001b[0m       \u001b[32m0.9461\u001b[0m        \u001b[35m0.1813\u001b[0m  0.7871\n",
      "      9        \u001b[36m0.2337\u001b[0m       \u001b[32m0.9496\u001b[0m        \u001b[35m0.1708\u001b[0m  0.7730\n",
      "     10        \u001b[36m0.2195\u001b[0m       \u001b[32m0.9532\u001b[0m        \u001b[35m0.1604\u001b[0m  0.7992\n",
      "     11        \u001b[36m0.2151\u001b[0m       \u001b[32m0.9547\u001b[0m        \u001b[35m0.1514\u001b[0m  0.7955\n",
      "     12        \u001b[36m0.2065\u001b[0m       \u001b[32m0.9560\u001b[0m        \u001b[35m0.1476\u001b[0m  0.7734\n",
      "     13        \u001b[36m0.2015\u001b[0m       \u001b[32m0.9563\u001b[0m        \u001b[35m0.1455\u001b[0m  0.8316\n",
      "     14        \u001b[36m0.1943\u001b[0m       \u001b[32m0.9587\u001b[0m        \u001b[35m0.1389\u001b[0m  0.7913\n",
      "     15        \u001b[36m0.1883\u001b[0m       0.9578        \u001b[35m0.1381\u001b[0m  0.8092\n",
      "     16        \u001b[36m0.1848\u001b[0m       \u001b[32m0.9596\u001b[0m        \u001b[35m0.1323\u001b[0m  0.7907\n",
      "     17        \u001b[36m0.1838\u001b[0m       \u001b[32m0.9606\u001b[0m        \u001b[35m0.1312\u001b[0m  0.7806\n",
      "     18        \u001b[36m0.1776\u001b[0m       \u001b[32m0.9623\u001b[0m        \u001b[35m0.1261\u001b[0m  0.7657\n",
      "     19        \u001b[36m0.1738\u001b[0m       \u001b[32m0.9625\u001b[0m        \u001b[35m0.1250\u001b[0m  0.7690\n",
      "     20        \u001b[36m0.1704\u001b[0m       \u001b[32m0.9627\u001b[0m        \u001b[35m0.1238\u001b[0m  0.7627\n"
     ]
    }
   ],
   "source": [
    "net.fit(X_train, y_train);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = net.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9627428571428571"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_test, y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An accuracy of about 96% for a network with only one hidden layer is not too bad.\n",
    "\n",
    "Let's take a look at some predictions that went wrong:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_mask = y_pred != y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAABbCAYAAABEQP/sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAEDZJREFUeJzt3Xt0VdWdwPHvTgghECAQIBAICa9EZWDogIgwPByqBRwo+AAUC8uqVFEGmap1zWrJKKuLAjXI8BirVCpYljgVqOIDeRRokVcUhRaCPAUUCCKihGeSPX/87km8l0tMIMnZuff3WYt1c889uex7uGfzO7/z23sbay1KKaX8F+N3A5RSSgntkJVSyhHaISullCO0Q1ZKKUdoh6yUUo7QDlkppRyhHbJSSjnCyQ7ZGHO9MWaNMea0MWavMWaY321ygTEmwxjzjjHmlDHmmDFmtjGmlt/t8osx5kzInyJjzCy/2+UCY8xIY8wuY0yBMWafMaa3323yW004f5zrkAMH6M/AcqAxMBZ41RiT6WvD3DAXyAdaAF2AvsA4X1vkI2ttovcHSAHOAf/nc7N8Z4y5FZgK3A/UB/oA+31tlBucP3+c65CB64BUYIa1tshauwbYAPzE32Y5oQ3wurX2vLX2GPAe0NHnNrniLuRk+6vfDXHAM8Cz1tpN1tpia+3n1trP/W6UA5w/f1zskM0Vtv1TdTfEQTOBkcaYusaYlsBA5EulYAywwEb5XADGmFigG9A0kO47Erg0T/C7bQ5w/vxxsUPOQyKdJ40xccaY25BLi7r+NssJ65D/0b8BjgC5wDJfW+QAY0xr5Dvyit9tcUAKEIdcMfRGLs1/APzSz0Y5wvnzx7kO2Vp7CRgK3A4cA34OvI4cwKhljIkBVgBLgHpAE6ARkiuMdqOBv1lrD/jdEAecCzzOstYetdZ+CeQAg3xsk+9qyvnjXIcMYK3dbq3ta61Nttb+CGgLbPG7XT5rDKQBs621F6y1J4H5RPmJFjAajY4BsNaeQoKXqE7dhFEjzh8nO2RjTGdjTJ1ArucJ5K7oH3xulq8Ckc4B4BFjTC1jTBKSN/3E35b5yxjTE2iJVld813xgvDGmmTGmEfA4UrUUtWrK+eNkh4xUVBxFcsn9gVuttRf8bZIT7gAGACeAvUAhMNHXFvlvDLDEWvut3w1xyGRgK/ApsAvYBvza1xa5wfnzx0T5TWmllHKGqxGyUkpFHe2QlVLKEdohK6WUI7RDVkopR2iHrJRSjqjQ1HO1TbytQ72qaosTzlPARXsh3HwaYUXDMQH4llNfWmublmdfPSbhRcNx0fMnvPJ+VyrUIdehHjeZ/lffqhpgs11dof2j4ZgArLJ/+qy8++oxCS8ajoueP+GV97uiKQullHKEdshKKeUI7ZCVUsoRTq0npZSqXHtm9gBg/90vANDzPx8GoP5rm3xrk7oyjZCVUsoRGiErFYFiOl8HwPSBiwAossV+NkeVk0bISinlCI2QHRCb0gyAw2PaA3Cuy1kAdvd9WV438v/m90U5WWsfACBtQSwAtVfkVn5jXdG9EwAFzxYA8JdOMj/9zZMeAyD59xv9aZfPYv75egDuWbwSgKH1vgZg4tGbAGjwxkeALifiKo2QlVLKERETIadsbADAhk03ANB+ott3kT//Rc+Sn+8dJaObnkx+N2gfLx5ee04i3o0F14V9r9a1vwRgV795AKzvURuAyeN/CkD8O1srp9EOOPfj7gBMmP4aADn7fwhA1hvjAEiO8tBv90NyHoyqnx+0fcuMrgAk9roIwMHB8h1p93O3z5PKUCujNQBf9UwFIH+QLD7UvMlpADZ0XgKUXoEu/LY5AJNzbwegyft1AGi0+KOS97QXqmYBI42QlVLKETU2Qj47THJif53zu6Dt7QIRsuvWPja95OeGMXXC7pO1ciwAbRbK81qrPwy737cj7gLgnpw5APSpI1HQ
0/+zAIBpF+8DIG5V+N+vCU4+cDMAm58NfMYd8pkTJ8nENB22bPanYY55pN+qoOcD8n4MQNKnkmu/2FAi449HPA9A/49lSbmkhZGXcy+4U/qIO56RfPr4RkvD7ncp5KpqVP2j8niLXHFyizxk9vtZyT6ZP62a+zMaISullCO0Q1ZKKUf4nrLwUg9tntoFXH5Tbu+MHkH77xvxQuCnj6ungZXs0m3dAIgzl18i/jJfbrzsuE2mTe1wMvAZi4uu6u/qnyDlc09fHw9Ayqqy9nZUoLxtSbakePrsGA1Aw+FyI7Pom/3lehsv5eGJtLK40/fJeTI2aUZgi/ybX3pOblDFbJUbu/GBASPrzicBcOJGuZGVtLC6Wlr14ta2AOD99rPkuYkNu9+EL3oBsObdHwDQpPtxAO5P/wCAe+ofCvr9pbfMKfnd/2o0AICiU6cqte0aISullCN8j5BDb8qRvl4eR3gbyhcJj/6sDwCp692ue4o/LjdXiu132hlYX2HZconiMk5ULHprNX5P2O3D9vw7AC3fPAxAYYXe1Q37H5fopEVsAgCJAyQiLu81w5CdJwEY23A2AG+fbQjASyv6AlB45PPKaqqvuk3YBkCiiS9zv+LteQDkHLwNgKd+uByApZR74RNnxWa2AyCnjdzMjjNys/zF0xkA/O+CwQAk7ZFvT7035EZwOsHn2+vIVcXcnw0DYNMk+e50rF3aXebfKVcayfMq90pLI2SllHKE7xFyZTl+8zcA1MXt8qfiTyRX3if3wZJtH3W/ugTembsl/z4rLSewRSKCMQdlsIQdKfnBwmOHr+r9XeBdRxR7w2QCOWW27Ai7f620VgDEvirXA2Mbfhj0+7+e8hMAGh+JrBxy58Tgf+NNgXELdU6cB6JjqPSuickAtKkVXEb61mi5GmqZ+0GF3i/lb18FPfeurgCaLvoEKB28VVk0QlZKKUf4HiH3flSKrb/oE7xQba8eO4HSqovS6grh5YwPTJPJVMqKjEMrNVwYVt3qntLqgKxpMuy3afhU8GVqpacBsDRHIuNGIQNLNu7oAEDmsS3X2kzfZf1CqimeeVMqUN5Z+gpQOjDk3tZSPbDo0I0AHN0judDd7ecCkJ0vd9A/GSzHLNIi4yuZckiG/dqt4a8kIlF88rmw22POykCpq6tVKvWrF0eX/Jx6tmLRdnlphKyUUo7wPUKuu1Qi2/YhoxqPBx5ThwWyXyOCXy+pV14aHO2G1jUD7A3s4lIFRvH58yU/d/iPiuW9dz4tdZahkfGRQokQsn4nlRzufNqrV3j4CFAa4faZJ1OU3p0mE73M2Cb58mZvBqoLesmn9nLG3u9FSjVFZUuNkzraWm2kPr7wQLlWq3dS4spE+aFX8PYTN0luufHOir2fOSlTl04/KX1N6z+VfoeqqmJJI2SllHKE7xHy9wnNLXtKcsyB/HDoCD4vxwxu5IwrU9sOx8JuH7j5EQDSt0Ve3tCLcBNlgBTvIiPN2rEtaL+Y+9vKYyDWiNbIOO+LFADacbTM/W6vewaA7NtbAtBsds2NkJutPwHAgUK5+vSqLUY9IdPavho7ECh/7XDhMblOX9c5IbCl6o+NRshKKeUIZyPkvZdFvsEWeCP6vMeA0OoLcL82ubwG/kNyWvc22BDYEpxDLsmhRSFvroqNnWRUVZ/twwFowD7f2lQdCu6SeyaD6j0X2FIXgPR54edviGRFu/cCMGL6kwA8MX4xAI8myXfgjkkyH8rQIbLUWcpwiXgLu8moOxsn8WnsOqkxvto5ZK6FRshKKeWIKo2QvYqH7+aBrxTxXq5is7l59cxe1UakRMUAscmNAeiRIKPOQqsr7v+sPwApqyVfWBPnrLhWp7Pk0RtN1fghySNG+rFosF3qtHdebARAi4SylxaKbSBLPKUEFj9deU7yoy3el3xp9ceEla/ZbKkRnpt/NwBf//fbAIxteBCAzV0XyY4lF08b+K7+4+ReTMKfq7+OXyNkpZRyRJVGyJfN5FaJ
2i1+GCitLfYi40iUly0j77rGrwza/vBhGaP/1TBZlqfo+MFqbZcLvLkrfjP0jwBsOSNVFtFSXVH0qYR5By8FZmtLOFLm/ue7y3dpfuuXAHjpdFrQ+0SSxNeluuqdXBnlOXOszH64dKSMcM2Mq+1Pw8qgEbJSSjmi2qssvCqIDSGLkX5fbtmLiL2a4vZEVm1xON78rm8PCV4F4nSx5EdzX+sMQPPjVTOuvibYPbUJAEPqyYizKVNHAZBMdMxZcSXnk+XUjt66m1KF+w8C0OZpeRwSLwu75g2fE7TfDeuk+qLtW/4tBqwRslJKOaJKI+QfpXYJs1XmLb4swh0RZldKqydC56yIBnkTJPprHxe8CsTfL9YHoPnz0RsZh/Lmroi0tfKuVpuJuwE4+VagIidGYq8DI6I3BvNmSZwzeH7Y1ztMljlginyoP/ZE77+OUko5xveReqVzFQfXHXvRdSTVE1dUl07hV1Qet+1eANL4e3U2x0kTu6wGoOPascDlc1tEq1fS1wDQ/9/kCvNSoozc2ztobtB+U9fLvMmZ1Py5s68oRj573uMyX4e3Gruny5zxALTK8//qSiNkpZRyhO8R8pVWAvFyzdHoUHZPAFZkTAtskdFUow7ISsEZD0mNbSSMqrpW3uirGWXvFvGmbBgEwAMDXwzafvZRGZGXnfVW0PY3CmRkXyTNnX0l+eNkxHDe8FlB2+d/Iznl9JdkqZ4i6/9R0AhZKaUc4VuE7M1zEZo7LlkJJArqjEPlPyqR8QcP/haAxJiEoNf3LcgEoNnF7QDE1KtX5vvZi5cAMLXjACguKKi8xjrCG604oYvkTN9Lk1WpvZVGokXWC1KbvrKffGduTZDVYzZ2WRy0X3EgFn5ustyHSNrmf960qhW0Ch/5Pr9oKABpJ9ypVtIIWSmlHKEdslJKOcK3lEXo0kyhQ6Oj0dedJcWQGBMf9vVN2TL5Otnle79uW+8DYH23lwEY/NiEktcSlkVGmdO6/bLo6dy+fwEgJ/tWADIfjK6Uhc2VEsjnR8qUkzGvSaqif8h0nFnLxgHQYWHkpyo83fvsCnp+wcp5lrHkJODWzXGNkJVSyhG+l72pUjdMlUUajwyQGzKtaiWUtTurz8lyPRM+lHHnRYVSAL+r3zwAcm98FYBLVrYf71q6rE/Gsspqtb+SVssxWt5Nlno/MEg+e6/37gAgcVLgxueWyFv4NRwvUn6ufUd5DHm9QxQNtDJd5Rg8lfr7wBa5uf11sSxbUPSP3X40q0waISullCM0QnaIN03g0JynABjz0HsAjG+0J2i/zPdlmHDbhfI8Y41MFxhTVyLmfxknQ0HPtJNIILaB5Mza/Sry8obeZEIvr+gNwBcrJEJc00lyqHdOGwKAvU+GzUbLxPUKLvzmDADXx8UFbe+96nEAMsmt9jZ9H42QlVLKEb5FyCXVFCNCniuaz5RC9RUzZUHKFXQNej2T8BNoF5+VSVNSf+tOoXt18SLf5R1lSPBybgy8ctSnFim/5a+Rq6JDmXJPZt8l+W403hx3xd/xm0bISinlCN9zyOEnsVdKqWvTaopcKY6b8q9B25s4vLyXRshKKeUI7ZCVUsoR2iErpZQjjK3ApMzGmBPAZ1XXHCekW2ublnfnKDkmUIHjosckvCg5LnpMwivXcalQh6yUUqrqaMpCKaUcoR2yUko5QjtkpZRyhHbISinlCO2QlVLKEdohK6WUI7RDVkopR2iHrJRSjtAOWSmlHPH/JpCSt/m0iD8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7ff5140a62b0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_example(X_test[error_mask], y_pred[error_mask])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convolutional Network\n",
    "PyTorch expects a 4-dimensional tensor as input for its 2D convolution layer. The dimensions represent:\n",
    "* Batch size\n",
    "* Number of channels\n",
    "* Height\n",
    "* Width\n",
    "\n",
    "The batch dimension holds the number of examples. MNIST data has only one channel. As stated above, each MNIST vector represents a 28x28 pixel image. Hence, the resulting shape for the PyTorch tensor needs to be (x, 1, 28, 28), where x is the number of examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "XCnn = X.reshape(-1, 1, 28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(70000, 1, 28, 28)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "XCnn.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "XCnn_train, XCnn_test, y_train, y_test = train_test_split(XCnn, y, test_size=0.25, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((52500, 1, 28, 28), (52500,))"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "XCnn_train.shape, y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Cnn(nn.Module):\n",
    "    def __init__(self, dropout=0.5):\n",
    "        super(Cnn, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)\n",
    "        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)\n",
    "        self.conv2_drop = nn.Dropout2d(p=dropout)\n",
    "        self.fc1 = nn.Linear(1600, 100) # 1600 = number channels * width * height\n",
    "        self.fc2 = nn.Linear(100, 10)\n",
    "        self.fc1_drop = nn.Dropout(p=dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(F.max_pool2d(self.conv1(x), 2))\n",
    "        x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
    "        \n",
    "        # flatten over channel, height and width = 1600\n",
    "        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))\n",
    "        \n",
    "        x = torch.relu(self.fc1_drop(self.fc1(x)))\n",
    "        x = torch.softmax(self.fc2(x), dim=-1)\n",
    "        return x"
   ]
  },
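  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `1600` passed to `fc1` follows from how the spatial size shrinks through the network. The sketch below traces it by hand, assuming the PyTorch defaults for these layers (stride 1 and no padding for the convolutions, a pooling stride equal to the kernel size, sizes rounded down). The helper names `conv_out` and `pool_out` are ours.\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel):\n",
    "    # Unpadded convolution with stride 1.\n",
    "    return size - kernel + 1\n",
    "\n",
    "def pool_out(size, kernel):\n",
    "    # Max pooling with stride equal to the kernel size, rounding down.\n",
    "    return size // kernel\n",
    "\n",
    "size = 28\n",
    "size = pool_out(conv_out(size, 3), 2)  # 28 -> 26 -> 13\n",
    "size = pool_out(conv_out(size, 3), 2)  # 13 -> 11 -> 5\n",
    "flat_features = 64 * size * size       # 64 channels after conv2\n",
    "print(size, flat_features)\n",
    "```\n",
    "\n",
    "If the kernel sizes, pooling, or input resolution change, the input size of `fc1` has to be recomputed in the same way."
   ]
  },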
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "\n",
    "cnn = NeuralNetClassifier(\n",
    "    Cnn,\n",
    "    max_epochs=10,\n",
    "    lr=0.002,\n",
    "    optimizer=torch.optim.Adam,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.4298\u001b[0m       \u001b[32m0.9729\u001b[0m        \u001b[35m0.0898\u001b[0m  4.9916\n",
      "      2        \u001b[36m0.1577\u001b[0m       \u001b[32m0.9799\u001b[0m        \u001b[35m0.0646\u001b[0m  4.8633\n",
      "      3        \u001b[36m0.1261\u001b[0m       \u001b[32m0.9824\u001b[0m        \u001b[35m0.0564\u001b[0m  4.9402\n",
      "      4        \u001b[36m0.1120\u001b[0m       \u001b[32m0.9848\u001b[0m        \u001b[35m0.0507\u001b[0m  4.8320\n",
      "      5        \u001b[36m0.1006\u001b[0m       \u001b[32m0.9855\u001b[0m        \u001b[35m0.0446\u001b[0m  4.8227\n",
      "      6        \u001b[36m0.0924\u001b[0m       \u001b[32m0.9862\u001b[0m        \u001b[35m0.0415\u001b[0m  4.8191\n",
      "      7        \u001b[36m0.0844\u001b[0m       \u001b[32m0.9886\u001b[0m        \u001b[35m0.0375\u001b[0m  4.8168\n",
      "      8        \u001b[36m0.0828\u001b[0m       0.9854        0.0414  4.8333\n",
      "      9        \u001b[36m0.0779\u001b[0m       0.9885        \u001b[35m0.0368\u001b[0m  4.8199\n",
      "     10        \u001b[36m0.0768\u001b[0m       \u001b[32m0.9891\u001b[0m        \u001b[35m0.0350\u001b[0m  4.8129\n"
     ]
    }
   ],
   "source": [
    "cnn.fit(XCnn_train, y_train);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred_cnn = cnn.predict(XCnn_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9889714285714286"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_test, y_pred_cnn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An accuracy of >98% should suffice for this example!\n",
    "\n",
    "Let's see how we fare on the examples that went wrong before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7745398773006135"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_test[error_mask], y_pred_cnn[error_mask])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "About 77% of the previously misclassified images are now correctly identified."
   ]
  },
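  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the masked comparison concrete, here is a minimal sketch of the same idea in plain Python. The labels and predictions are made-up toy values, not taken from the MNIST data: the accuracy of the second model is computed only at the positions where the first model was wrong.\n",
    "\n",
    "```python\n",
    "# Hypothetical toy labels and predictions, for illustration only.\n",
    "y_true = [3, 5, 8, 1, 7, 2]\n",
    "old_pred = [3, 6, 8, 7, 7, 0]  # first model: wrong at positions 1, 3, 5\n",
    "new_pred = [3, 5, 8, 7, 7, 2]  # second model: fixes positions 1 and 5\n",
    "\n",
    "# True wherever the first model made a mistake.\n",
    "error_mask = [p != t for p, t in zip(old_pred, y_true)]\n",
    "\n",
    "# Score the second model only on those positions.\n",
    "hits = [n == t for n, t, m in zip(new_pred, y_true, error_mask) if m]\n",
    "recovered = sum(hits) / len(hits)\n",
    "print(round(recovered, 3))\n",
    "```\n",
    "\n",
    "With NumPy arrays, boolean indexing as in `y_test[error_mask]` selects the same positions in one step."
   ]
  },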
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAABbCAYAAABEQP/sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAEJVJREFUeJzt3Xl01FWWwPHvy0IICWuQfQlbIjLQzICIMAIOrQ04IIiCgo3HVlFRBplW2zNnlG49c2igBRkBbUVtwaaFaYFWXJClgRZZRBFoIQiyt2yiIoQ1yZs/bv0SqihiAkneo+p+zvFU8uqX8uZH1cv9vd997xlrLUoppdxLcB2AUkopoR2yUkp5QjtkpZTyhHbISinlCe2QlVLKE9ohK6WUJ7RDVkopT3jZIRtjjkf8l2+Med51XC4ZY1KMMa8YY3YbY44ZY9YbY3q7jss1Y0ymMeY9Y8x3xpgDxpgpxpgk13G5ZIypZYyZZ4zJDb1fhriOyQfGmNbGmKXGmKPGmO3GmAGuY4rkZYdsrU0P/gPqAieB/3MclmtJwF6gO1AdeBKYY4zJdBiTD6YBh4D6QHvk/IxwGpF7U4EzyGdnKPCCMaaN25DcCv2R/guwAKgFDAfeMMZkOQ0sgpcdcoRbkQ/c31wH4pK1Ntda+2tr7S5rbYG1dgGwE+jgOjbHmgFzrLWnrLUHgA+AuO18jDFpwEDgSWvtcWvtR8DbwM/dRubclUADYJK1Nt9auxRYiWfn5XLokO8CZlid4x3GGFMXyAK+cB2LY5OB240xVYwxDYHeSKccr7KAfGvtl+e0bSCO/0iFmAu0/VNFB1IcrztkY0wT5BL0ddex+MQYkwz8EXjdWpvjOh7HliOdzQ/APmAdMN9pRG6lA0cj2o4CVR3E4pMc5Er7MWNMsjHmRqRvqeI2rHBed8jAMOAja+1O14H4whiTAMxExggfdhyOU6FzsRCYC6QBtYGawDiXcTl2HKgW0VYNOOYgFm9Ya88C/YGbgAPAL4E5yB9xb1wOHbJmxyHGGAO8gtysGRh6k8WzWkBjYIq19rS19gjwGtDHbVhOfQkkGWNandP2E3RoC2vtRmttd2tthrX2Z0BzYK3ruM7lbYdsjOkCNESrK871AtAa6GutPek6GNestd8gNzYfNMYkGWNqIPccNriNzB1rbS5yxfC0MSbNGNMVuBm5qoprxph2xpjKofsNjyKVOX9wHFYYbztk5IM111ob15daAWNMU+B+pLTrwDk12kMdh+baLUAv4DCwHcgDRjuNyL0RQCoyZvon4EFrbdxnyEhFxX7kvPQEbrDWnnYbUjijxQtKKeUHnzNkpZSKK9ohK6WUJ7RDVkopT2iHrJRSntAOWSmlPFGqZQormRRbmbTyisULp8jljD0dbd57VPFwTgCO8d031torSnKsnpPo4uG86OcnupK+V0rVIVcmjWtMz4uP6jKwxi4p1fHxcE4AFts/7y7psXpOoouH86Kfn+hK+l7RIQullPKEdshKKeUJ7ZCVUsoTcb33mFKxbtvkzgDsuO1FALr85wMAVH1ztbOY1IVphqyUUp7QDFmpGJTQ7koAJvSeBUC+LXAZjiohzZCVUsoTmiF7ILFuHQD23tUSgJPtTwCwtfur8ryRv5s/luVkL7sHgMYzEgGotHBd2Qfri05tAch9OheAv7aVfQyufUp2tcp4ZZWbuBxL+ElrAO6YvQiA/mnfAzB6/zUAVHvrMwB00V0/aYaslFKeiJkMue4q2ddx5eqrAGg52u+7yP/4VZfCr4cMldlNj2W8H3ZMkA8vOykZ76rcK6O+VpNK3wCwpcd0AFZ0rgTAMyN/AUDKe5+UTdAeOHlzJwBGTXgTgIk7fgpA9lsjAMiI89Rv633yORha9VBY+9pJHQBI73oGgF195T3S4pd+f07KQlJmEwC+7dIAgEN9ZJOQerVlc+6V7eYCRVegM4/VA+CZ
dTcBUPvDygDUnP1Z4Wva0+Wz0YhmyEop5YnLNkM+MUDGxP429fdh7S1CGbLvlj08ofDr6gmVox6TvWg4AM1C21MmLfk06nHHBt8KwB0TpwLQrbJkQU/87wwAxp+5E4DkxdF//nJw5J5rAVjzdOh33CS/c/pTsjBNq7Vr3ATmmQd7LA77vlfOzQDU+FLG2s9Ul8z488HPAdDzc9l+sMbM2Btzzx0ofcQtv5Hx9JE150U97mzEVdXQqvvl8Xq54uR6ecjqcX/hMVm/KJ/7M5ohK6WUJ7RDVkopTzgfsgiGHpo9vgU4/6bc9kmdw47/avCLoa8+r5gAy9jZGzsCkGzOv0T870Ny42XTjbJsaqsjod+xIP+i/l89U6V87onWKQDUXVzc0Z4KlbfNHSNDPN02DQOg+iC5kZn/w44SvUww5BGItbK4o3fK52R4jUmhFvk3P/us3KBK+ERu7KaEJowsP1UDgMNXy42sGjMrKtLyl7ysPgAftnxevjeJUY8b9XVXAJa+/88A1O50EIC7m34MwB1V94T9/Lzrpxb+7H/V7AVA/nfflWnsmiErpZQnnGfIkTflaLpCHgcHDSXLhIft7gZAgxV+1z2lHJSbKwX2nDhD+yvMXyBZXObh0mVvjUZui9o+YNu/A9Dw7b0A5JXqVf2w4xHJTuonpgKQ3ksy4pJeM/TbfASA4dWnAPDuieoAvLywOwB5+/5RVqE61XHUegDSTUqxxxVszAFg4q4bAXj8pwsAmEeJNz7xVmJWCwAmNpOb2clGbpa/dDQTgBdm9AWgxjZ596S9JTeCmxL+eZuDXFVMu38AAKufkvdOm0pF3eWhgXKlkTG9bK+0NENWSilPOM+Qy8rBa38AoAp+lz8VbJCx8m7r7i1s+6zTxQ3gHb9Nxt+fbzwx1CIZwV27ZLKEvV3GB/MO7L2o1/dBcB1REEyTCY0ps3ZT1OOTGjcCIPENuR4YXv3TsJ//n7E/B6DWvtgaQ26XHv5vvDo0b6Hy4VNAfEyV3jI6A4BmSeFlpO8Mk6uhhus+LtXr1f3o27Dvg6srgCtmbQCKJm+VFc2QlVLKE84z5OsekmLrr7uFb1TbtfNmoKjqoqi6QgRjxjvHy2IqxWXGkZUaPkyrbnRHUXVA9niZ9ntF9KHg8yQ1bQzAvImSGdeMmFiyalMrALIOrL3UMJ3L/pVUU/zmbalAeW/e60DRxJAhTaR6YNaeqwHYv03GQre2nAbAmENyB31DXzlnsZYZX8jYPTLt134S/UoiFqVknIzannBCJkpdXK1SkSdfGlb4dYMTpcu2S0ozZKWU8oTzDLnKPMlsW0bMajwYemwwIDT6NTj8+cJ65Xnh2W5kXTPA9tAhPlVgFJw6Vfh1q/8o3bj35iekzjIyM96XJxlC9u+lksOf3/bi5e3dBxRluN2myxKltzWWhV4mrZfx8jpvh6oLuspvHYwZBz8XK9UUZa1BstTRJjWT+vi8nSXard5L6YvS5Yuu4e2Hr5Gx5VqbS/d65ogsXTrhiPQ1Tf5c9B4qr4olzZCVUsoTzjPkHxM5thwoHGMOjQ9HzuALxpjBjzHjstS81YGo7b3XPAhA0/WxN24YZLjpMkGK95GZZi1YH3Zcwt3N5TGUa8RrZpzzdV0AWrC/2ONuqnIcgDE3NQSgzpTLN0Ous+IwADvz5OozqLYY+qgsa/tGYm+g5LXDeQfkOn15u9RQS/mfG82QlVLKE95myNvPy3zDzQhm9AWPIZHVF+B/bXJJ9f5CxrSGVFsZagkfQy4cQ4tDwVoVq9rKrKpuGwcBUI2vnMVUEXJvlXsmfdKeDbVUAaDp9OjrN8Sy/K3bARg84TEAHh05G4CHash74JanZD2U/v1kq7O6gyTjzesos+5ssuSniculxvhi15C5FJohK6WUJ8o1Qw4qHs4dB75Qxnu+0q3mFtQzB1UbsZIVAyRm1AKgc6rMOousrrh7d08A6i6R8cLLcc2KS3U0Wx6D2VS17pNxxFg/F9U2
Sp325jM1AaifWvzWQonVZIunuqHNTxedlPHR+h/KeGnF54Rlr84UqRGedug2AL7/9bsADK++C4A1HWbJgYUXTys5V88Rci8m9S8VX8evGbJSSnmiXDPk81ZyK0MtZj8AFNUWB5lxLMoZIzPvOqQsCmt/YK/M0f92gGzLk39wV4XG5YNg7Yrf9v8jAGuPS5VFvFRX5H8pad6us6HV2lL3FXv8qU7yXnqtycsAvHy0cdjrxJL0OVJd9d46meU5ebisfjjvdpnhmpVcyU1gxdAMWSmlPFHhVRZBFcTKiM1If2xsOciIg5rilsRWbXE0wfqu7/YL3wXiaIGMj657sx0A9Q6Wz7z6y8HWcbUB6JcmM87GjhsKQAbxsWbFhZzKkI92/NbdFMnbsQuAZk/IY78U2dg1Z9DUsOOuWi7VF83fcbcZsGbISinliXLNkH/WoH2UVlm3+LwMd3CUQymqnohcsyIe5IyS7K9lcvguEH8/UxWAes/Fb2YcKVi7Itb2yrtYzUZvBeDIO6GKnATJvXYOjt8cLFglcWrf16I+3+oZWQMm30H9cSB+/3WUUsozzmfqFa1VHF53HGTXsVRPXFrt20bfUXnE+iEANObvFRmOl0a3XwJAm2XDgfPXtohXrzddCkDPf5MrzLPpMnNve59pYceNWyHrJmdx+a+dfUEJ8rvnPCLrdQS7sQfaTx0JQKMc91dXmiErpZQnnGfIF9oJJBhrjkd7xnQBYGHm+FCLzKYaulN2Cs68T2psY2FW1aUKZl9NKv6wmDd2ZR8A7un9Ulj7iYdkRt6Y7HfC2t/KlZl9sbR29oUcGiEzhnMGPR/W/toPMqbc9GXZqiffuj8LmiErpZQnnGXIwToXkWPHhTuBxEGdcaRDD0lm/PG9vwMgPSE17PmvZmQBUOfMRgAS0tKKfT175iwAplIyAAW5uWUXrCeC2Yqj2suY6QeNZVfqYKeReJH9otSmL+oh75kbUmX3mFXtZ4cdVxDKhZ99Ru5D1Fjvfty0vOU2ip75PjerPwCND/tTraQZslJKeUI7ZKWU8oSzIYvIrZkip0bHo+/byRBDekJK1OdXj5HF1xlTstfr+MmdAKzo+CoAfR8eVfhc6vzYKHNavkM2PZ3W/a8ATBxzAwBZ98bXkIVdJyWQz90uS04mvClDFT0jluPMnj8CgFYzY3+oItCp25aw709b+Zxlzj0C+HVzXDNkpZTyhPOyN1XkqnGySeO+XnJDplFSanGHs+SkbNcz6lOZd56fJwXwW3pMB2Dd1W8AcNZK+8EORdv6ZM4vq6jdqrFEztGCjrLV+84+8rt3/eAWANKfCt34XBt7G79GE2TKz7ZsI48Rz7eKo4lWpoOcg8cbvBJqkZvb3xfItgX5X2x1EVaxNENWSilPaIbskWCZwP4THwfgrvs+AGBkzW1hx2V9KNOEm8+U7zOXynKBCVUkY/6XETIV9HgLyQQSq8mYWYsnY2/cMFhM6NWF1wHw9ULJEJe2lTHUgeP7AWDvlGmz8bJwvYLTvz0OQOvk5LD26xY/AkAW6yo8ph+jGbJSSnnCWYZcWE0xOOJ7Rb3JUqi+cLJsSLmQDmHPZxF9Ae2CE7JoSoPf+VPoXlGCzHdBG5kSvICrQ8/sdxSRcu3QUrkq2pMl92S+OivvjVprki/4M65phqyUUp5wPoYcfRF7pZS6NI3GypXiiLH/GtZe2+PtvTRDVkopT2iHrJRSntAOWSmlPGFsKRZlNsYcBnaXXzheaGqtvaKkB8fJOYFSnBc9J9HFyXnRcxJdic5LqTpkpZRS5UeHLJRSyhPaISullCe0Q1ZKKU9oh6yUUp7QDlkppTyhHbJSSnlCO2SllPKEdshKKeUJ7ZCVUsoT/w/7t45ePF+BTAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7ff51229f7f0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the test images selected by error_mask, labelled with the CNN's predictions for them\n",
    "plot_example(X_test[error_mask], y_pred_cnn[error_mask])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
